"""
层和块
花式自行定义神经网络
"""

import tensorflow as tf

class MLP(tf.keras.Model):
    # 用模型参数声明层，这里，我们声明两个全连接层
    def __init__(self):
        # 调用MLP的父类Model的构造函数执行必要的初始化。
        # 这样，在类实例化时也可以指定其他函数参数，例如模型参数params（稍后将介绍）
        super().__init__()
        # Hiddenlayer
        self.hidden = tf.keras.layers.Dense(units=256, activation=tf.nn.relu)
        self.out = tf.keras.layers.Dense(units=10)

    # 定义模型的前向传播，即如何根据输入X返回所需的模型输出
    def call(self, X):
        # hidden的计算结果作为参数传给out层
        return self.out(self.hidden(X))

class MySequential(tf.keras.Model):
    def __init__(self, *args):
        super().__init__()
        self.modules = []
        for block in args:
            # 这里，block是tf.keras.layers.Layer子类的一个实例
            self.modules.append(block)

    def call(self, X):
        for module in self.modules:
            X = module(X)
        return X

class FixedHiddenMLP(tf.keras.Model):
    def __init__(self):
        super().__init__()
        self.flatten = tf.keras.layers.Flatten()
        # 使用tf.constant函数创建的随机权重参数在训练期间不会更新（即为常量参数）
        self.rand_weight = tf.constant(tf.random.uniform((20, 20)))
        self.dense = tf.keras.layers.Dense(20, activation=tf.nn.relu)

    def call(self, inputs):
        X = self.flatten(inputs)
        # 使用创建的常量参数以及relu和matmul函数
        X = tf.nn.relu(tf.matmul(X, self.rand_weight) + 1)
        # 复用连接层。这相当于两个全连接层共享参数。
        X = self.dense(X)
        # 控制流
        while tf.reduce_sum(tf.math.abs(X)) > 1:
            X /= 2
        return tf.reduce_sum(X)

class NestMLP(tf.keras.Model):
    def __init__(self):
        super().__init__()
        self.net = tf.keras.Sequential()
        self.net.add(tf.keras.layers.Dense(64, activation=tf.nn.relu))
        self.net.add(tf.keras.layers.Dense(32, activation=tf.nn.relu))
        self.dense = tf.keras.layers.Dense(16, activation=tf.nn.relu)

    def call(self, inputs):
        return self.dense(self.net(inputs))

if __name__ == '__main__':
    net = tf.keras.models.Sequential([
        tf.keras.layers.Dense(256, activation=tf.nn.relu),
        tf.keras.layers.Dense(10),
    ])
    X = tf.random.uniform((2, 20))
    # print("X.shape,", X.shape)
    # print("net(X).shape", net(X).shape)
    # print(net(X))

    net = MLP()
    # print("net(X).shape", net(X).shape)
    # print(net(X))

    net = MySequential(
        tf.keras.layers.Dense(units=256, activation=tf.nn.relu),
        tf.keras.layers.Dense(10)
    )
    # print(net(X))

    net = FixedHiddenMLP()
    # print(net(X))

    chimera = tf.keras.Sequential()
    chimera.add(NestMLP())
    chimera.add(tf.keras.layers.Dense(20))
    chimera.add(FixedHiddenMLP())
    # print(chimera(X))